
[Feature] Support IP-Adapter Plus #5915

Merged: 18 commits into huggingface:main, Dec 4, 2023

Conversation

@okotaku (Contributor) commented Nov 24, 2023

What does this PR do?

It adds support for IP-Adapter Plus checkpoints. Example with Stable Diffusion v1.5:

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

prompt = ''

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=torch.float16
    )
pipe.to('cuda')
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")

image = pipe(
    prompt,
    ip_adapter_image=image,
    height=512,
    width=512,
).images[0]
image.save('demo.png')

[image: demo]

Example with SDXL (using the ViT-H image encoder):

import torch
from diffusers import DiffusionPipeline, AutoencoderKL
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection


prompt = ''

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to('cuda')
vae = AutoencoderKL.from_pretrained(
    'madebyollin/sdxl-vae-fp16-fix',
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    image_encoder=image_encoder,
    vae=vae,
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=torch.float16
    )
pipe.to('cuda')
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.bin")

image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")

image = pipe(
    prompt,
    ip_adapter_image=image,
    height=1024,
    width=1024,
).images[0]
image.save('demo.png')

[image: demo_2]

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu (Collaborator)

oh thanks so much @okotaku

the first example (SD1.5) does not look nice - is the result expected?

@okotaku (Contributor, Author) commented Nov 27, 2023

@yiyixuxu I think that the quality of IP-Adapter-Plus-XL is promising, but IP-Adapter-Plus-SDv1.5 is not good.

@okotaku (Contributor, Author) commented Nov 27, 2023

@yiyixuxu Another example of IP-Adapter-Plus-SDv1.5:

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

prompt = ''

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=torch.float16
    )
pipe.to('cuda')
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

#image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
image = load_image("https://github.com/huggingface/diffusers/assets/24734142/6f15c2c6-7a78-43c9-9985-058e895c64f2")

image = pipe(
    prompt,
    ip_adapter_image=image,
    height=512,
    width=512,
).images[0]
image.save('demo.png')

Input: [image: input]

Output: [image: demo]

@alexblattner

The plus model doesn't seem to work:

Traceback (most recent call last):
  File "/mecomics-api/rubberUse/predict.py", line 269, in <module>
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin",cache_dir="model_cache")
  File "/home/alexblattnershalom/.local/lib/python3.10/site-packages/diffusers/loaders/ip_adapter.py", line 152, in load_ip_adapter
    self.unet._load_ip_adapter_weights(state_dict)
  File "/home/alexblattnershalom/.local/lib/python3.10/site-packages/diffusers/loaders/unet.py", line 711, in _load_ip_adapter_weights
    clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
KeyError: 'proj.weight'
pipe = StableDiffusionPipeline.from_single_file(
    "./poselabs.safetensors",
    cache_dir="model_cache",
    local_files_only=True
)
pipe.to(torch_device="cuda", torch_dtype=torch.float16)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin", cache_dir="model_cache")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipe(prompt="best quality, high quality, at the beach", ip_adapter_image=image1, generator=generator, num_images_per_prompt=4, num_inference_steps=20).images

@yiyixuxu (Collaborator)

@okotaku

the official example for IP-Adapter Plus Face seems OK, and it uses SD1.5 too - can we try this example? https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter-plus-face_demo.ipynb

also, can we try a multimodal prompts example, like https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter-plus_demo.ipynb? (a sketch follows below)

thanks!
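
For reference, a minimal multimodal-prompt sketch in diffusers, reusing the pipe and image from the first example above (the prompt text and scale value are illustrative assumptions):

# Combine a text prompt with the image prompt; lowering the adapter scale
# leaves more room for the text to steer the result.
pipe.set_ip_adapter_scale(0.6)
image = pipe(
    "best quality, high quality, wearing sunglasses in a garden",
    ip_adapter_image=image,
    height=512,
    width=512,
).images[0]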

@yiyixuxu (Collaborator) left a comment

thanks! @okotaku
I left my feedback, but I will wait for @patrickvonplaten's review before making any changes :)

Comment on lines 455 to 467
if isinstance(self.unet.encoder_hid_proj, ImageProjection):
    # IP-Adapter
    image_embeds = self.image_encoder(image).image_embeds
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = torch.zeros_like(image_embeds)
else:
    # IP-Adapter Plus
    image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2]
    uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

can we add an argument return_hidden_states to the encode_image method? I want to keep the encode_image method generic and not specific to IP-Adapter.

Also, update the variable names so it is clear whether we are returning image embeddings or hidden states.

Suggested change:

Before:

if isinstance(self.unet.encoder_hid_proj, ImageProjection):
    # IP-Adapter
    image_embeds = self.image_encoder(image).image_embeds
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = torch.zeros_like(image_embeds)
else:
    # IP-Adapter Plus
    image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2]
    uncond_image_embeds = uncond_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

After:

if not return_hidden_states:
    # IP-Adapter
    image_embeds = self.image_encoder(image).image_embeds
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_embeds = torch.zeros_like(image_embeds)
    return image_embeds, uncond_image_embeds
else:
    # IP-Adapter Plus
    image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
    image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
    uncond_image_enc_hidden_states = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2]
    uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
    return image_enc_hidden_states, uncond_image_enc_hidden_states

@@ -790,3 +790,155 @@ def forward(self, caption, force_drop_ids=None):
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class PerceiverAttention(nn.Module):
should we refactor this using Attention?
cc @patrickvonplaten


Yes, it would be important to use the Attention class here IMO, to make sure it can be used with torch's scaled-dot-product attention.
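
A rough sketch of the direction (an illustration, not the merged code; the dimensions are placeholder assumptions):

from diffusers.models.attention_processor import Attention

# Generic attention block; on PyTorch 2.x its default processor dispatches to
# torch.nn.functional.scaled_dot_product_attention. Dimensions are illustrative.
perceiver_attn = Attention(
    query_dim=768,
    dim_head=64,
    heads=12,
    bias=False,
    out_bias=False,
)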

@okotaku (Contributor, Author) commented Nov 30, 2023

@patrickvonplaten @yiyixuxu
When I refactored PerceiverAttention, I noticed that there are still some differences between out1 and out2 in float16. What are your thoughts on this?

https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py#L72

import torch
import math
import torch.nn.functional as F

# Standard attention scale vs. the "split" scale used by the original
# IP-Adapter resampler, which is applied to Q and K separately before the matmul.
scale = 1 / math.sqrt(64)
scale2 = 1 / math.sqrt(math.sqrt(64))

query = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)
key = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)
value = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float16)

# 1) PyTorch fused attention.
out1 = F.scaled_dot_product_attention(query, key, value, scale=scale)

# 2) Original resampler formulation: scale Q and K separately.
weight = (query * scale2) @ (key * scale2).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out2 = weight @ value
print(torch.allclose(out1, out2, atol=1e-4))

# 3) Plain attention with the scale applied after the matmul.
weight2 = query @ key.transpose(-2, -1) * scale
weight2 = torch.softmax(weight2.float(), dim=-1).type(weight2.dtype)
out3 = weight2 @ value
print(torch.allclose(out1, out3, atol=1e-4))
print(torch.allclose(out2, out3, atol=1e-4))

print(torch.abs(out1 - out2).sum())
print(torch.abs(out3 - out2).sum())
print(torch.abs(out1 - out3).sum())

Output:
False
False
False
tensor(0.1144, device='cuda:0', dtype=torch.float16)
tensor(0.0535, device='cuda:0', dtype=torch.float16)
tensor(0.1212, device='cuda:0', dtype=torch.float16)

@okotaku (Contributor, Author): the same comparison in float32:

import torch
import math
import torch.nn.functional as F

scale = 1 / math.sqrt(64)
scale2 = 1 / math.sqrt(math.sqrt(64))

query = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)
key = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)
value = torch.rand(1, 4, 8, 64, device="cuda", dtype=torch.float32)
out1 = F.scaled_dot_product_attention(query,key,value, scale=scale)

weight = (query * scale2) @ (key * scale2).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out2 = weight @ value
print(torch.allclose(out1, out2, atol=1e-4))

weight2 = query @ key.transpose(-2, -1) * scale
weight2 = torch.softmax(weight2.float(), dim=-1).type(weight2.dtype)
out3 = weight2 @ value
print(torch.allclose(out1, out3, atol=1e-4))
print(torch.allclose(out2, out3, atol=1e-4))

print(torch.abs(out1 - out2).sum())
print(torch.abs(out3 - out2).sum())
print(torch.abs(out1 - out3).sum())

Output:
True
True
True
tensor(0.0001, device='cuda:0')
tensor(4.6417e-05, device='cuda:0')
tensor(0.0001, device='cuda:0')

@yiyixuxu (Collaborator):

The numerical difference for float32 is small, no? That means your implementation is most likely correct.
For float16, can we try to see if the difference is less than 1e-3?
Also, let's generate some outputs with the refactored code - if the results look similar to before, it should be fine!
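
For example, a small sketch reusing out1 and out2 from the float16 snippet above:

# Check whether the fp16 outputs agree to within the looser 1e-3 tolerance.
print(torch.allclose(out1, out2, atol=1e-3))
print(torch.abs(out1 - out2).max())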

@okotaku (Contributor, Author):

The output image shows no problems.

[image: demo]

@yiyixuxu (Collaborator)

related to #5911

@okotaku (Contributor, Author) commented Nov 28, 2023

DDIMScheduler and stabilityai/sd-vae-ft-mse work here. Thank you @fabiorigano.

import torch
from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler
from diffusers.utils import load_image

prompt = 'photo of a beautiful girl wearing casual shirt in a garden'

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse",
                                    torch_dtype=torch.float16)
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    vae=vae,
    scheduler=noise_scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=torch.float16
    )
pipe.to('cuda')
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus-face_sd15.bin")

image = load_image("https://github.com/huggingface/diffusers/assets/24734142/cd1c16d0-8313-4e12-8fe1-25ad82ca6b1b")

image = pipe(
    prompt,
    ip_adapter_image=image,
    height=512,
    width=512,
).images[0]
image.save('demo.png')

Input: [image: ai_face]

Output: [image: demo]

@okotaku (Contributor, Author) commented Nov 28, 2023

import torch
from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler
from diffusers.utils import load_image

prompt = 'best quality, high quality, wearing sunglasses in a garden'

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse",
                                    torch_dtype=torch.float16)
noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V4.0_noVAE",
    vae=vae,
    scheduler=noise_scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=torch.float16
    )
pipe.to('cuda')
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
pipe.set_ip_adapter_scale(0.6)

image = load_image("https://github.com/huggingface/diffusers/assets/24734142/6f15c2c6-7a78-43c9-9985-058e895c64f2")

image = pipe(
    prompt,
    ip_adapter_image=image,
    height=512,
    width=512,
).images[0]
image.save('demo.png')

Input: [image: input]

Output: [image: demo]

@blx0102 commented Nov 29, 2023

Could this be merged? Looking forward to using this.

@alexblattner

@blx0102 me too. I am checking every day for this haha

@patrickvonplaten (Contributor)

Results look amazing here @okotaku! It would be great if we could try to re-use our fast Attention class for the perceiver attention.

@yiyixuxu mentioned this pull request Nov 30, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@@ -496,11 +496,22 @@ def encode_image(self, image, device, num_images_per_prompt):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        if self.image_encoder.config.output_hidden_states:

instead of changing the config for image_encoder, let's:

  1. add an argument output_hidden_state to encode_image()
  2. inside the ip-adapter specific code in the pipelines (i.e. here), use the same logic you used before to determine the value of output_hidden_state and pass it to encode_image(), e.g. output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True

essentially we just want to keep encode_image more generic and put the ip-adapter specific logic into the if ip_adapter_image is not None: ... branch (see the sketch below).
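
A minimal sketch of that pipeline-side logic, assuming the surrounding pipeline names (ip_adapter_image, device, num_images_per_prompt) and the ImageProjection class from diffusers.models.embeddings:

# Plain ImageProjection -> pooled image embeds; the Plus resampler -> hidden states.
if ip_adapter_image is not None:
    output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
    image_embeds, negative_image_embeds = self.encode_image(
        ip_adapter_image, device, num_images_per_prompt, output_hidden_state
    )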

@@ -1526,6 +1559,9 @@ def __call__(
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if attn.concat_kv_input:

this step can be done outside of the attention processor
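
A toy sketch of the idea, with illustrative shapes: do the concat in the module's forward pass, before any attention processor runs.

import torch

# Stand-ins for the resampler's latent queries and the image-encoder hidden states.
hidden_states = torch.rand(1, 16, 768)
encoder_hidden_states = torch.rand(1, 257, 768)

# Concatenate along the sequence dimension up front, so the attention
# processor itself never needs a concat_kv_input branch.
kv_input = torch.cat([encoder_hidden_states, hidden_states], dim=-2)  # (1, 273, 768)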

Comment on lines 1550 to 1552
if attn.layer_norm is not None:
    hidden_states = attn.layer_norm(hidden_states)


Suggested change: remove these two lines.

can we do this outside of the attention processor too?

Comment on lines 1562 to 1563

if attn.concat_kv_input:
    encoder_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=-2)

Suggested change: remove these lines.

@okotaku (Contributor, Author) commented Dec 1, 2023

@yiyixuxu

I ran it, but no changes were made:

➜  /workspace git:(ip_adapter_plus) ✗ make fix-copies
python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
➜  /workspace git:(ip_adapter_plus) ✗ 

@yiyixuxu (Collaborator) commented Dec 1, 2023

cc @DN6 @patrickvonplaten
I also ran make fix-copies and did not detect any changes - not sure why the quality test fails here.

Comment on lines 851 to 859
def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential:
    """Get feedforward network."""
    inner_dim = int(embed_dims * ffn_ratio)
    return nn.Sequential(
        nn.LayerNorm(embed_dims),
        nn.Linear(embed_dims, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, embed_dims, bias=False),
    )

Can we replace this with a layer norm layer first and then the FeedForward(nn.Module) class? (a sketch follows below)
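
A sketch of what that could look like, assuming diffusers' generic FeedForward block (this mirrors the shape of the change that was eventually merged, shown further below):

from torch import nn

from diffusers.models.attention import FeedForward


def _get_ffn(self, embed_dims, ffn_ratio=4) -> nn.Sequential:
    """LayerNorm followed by the generic FeedForward block."""
    return nn.Sequential(
        nn.LayerNorm(embed_dims),
        FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
    )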

@@ -489,18 +489,29 @@ def encode_prompt(

        return prompt_embeds, negative_prompt_embeds

-    def encode_image(self, image, device, num_images_per_prompt):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states):

Suggested change:

-    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states):
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):

This is a public API, so we need to make sure it's backward compatible. Also, ideally we should add docstrings here.
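
For instance, a docstring sketch for the backward-compatible signature (the wording is an assumption, not the merged text):

def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an image prompt for the IP-Adapter.

    Args:
        image: The image(s) to encode.
        device: Device to run the image encoder on.
        num_images_per_prompt (int): How many times to repeat the embeddings.
        output_hidden_states (bool, optional): If True, return the penultimate
            hidden states (used by IP-Adapter Plus); otherwise return the
            projected image embeddings.
    """
    ...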

@patrickvonplaten (Contributor) left a comment

Looking almost ready to be merged, just some minor comments. Let's try to re-use our FeedForward class here :-)


-image_projection.load_state_dict(image_proj_state_dict)
+image_projection.load_state_dict(new_sd)
+del image_proj_state_dict

OK for now, but let's make sure to factor this out into a conversion_... function later.

Comment on lines +848 to +851

nn.Sequential(
    nn.LayerNorm(hidden_dims),
    FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
),
nice!

@patrickvonplaten merged commit 0a08d41 into huggingface:main on Dec 4, 2023
20 checks passed
@patrickvonplaten (Contributor)

Great job @okotaku - super cool addition :-)

@okotaku (Contributor, Author) commented Dec 4, 2023

@patrickvonplaten @yiyixuxu Thank you for your reviews!

@alexblattner

Hey, does it work with IP-Adapter Plus yet? I still get:

    clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
KeyError: 'proj.weight'

@vladmandic (Contributor)

FYI, some examples... IMO, ip-adapter-plus produces the best results so far.

[image: xyz_grid-xyz_grid-0016-absolutereality_v1-beautiful woman wearing a gown in a city]

@yiyixuxu (Collaborator) commented Dec 7, 2023

@vladmandic I think the face models need to use DDIM, see #5911 (comment).

@vladmandic (Contributor)

Here's a different example with DDIM at 50 steps - it works, but faces are still distorted and nowhere near the quality of ip-adapter-plus.

[image: xyz_grid-xyz_grid-0017-realisticVisionV30_v30VAE-]

@yiyixuxu (Collaborator) commented Dec 7, 2023

cc @xiaohu2015 is this expected?

@xiaohu2015 (Contributor)

Hi, for the face model you should use a crop of the face image.
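
For example, a rough sketch of cropping before inference (the box coordinates are illustrative assumptions; in practice you would locate the face with a detector):

from diffusers.utils import load_image

face_image = load_image("https://github.com/huggingface/diffusers/assets/24734142/cd1c16d0-8313-4e12-8fe1-25ad82ca6b1b")
# Crop roughly around the face, then resize to the model's input resolution.
face_crop = face_image.crop((0, 0, 512, 512)).resize((512, 512))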

@MikeHanKK

It seems controlnet img2img does not support ip adapter plus. @yiyixuxu

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Support IP-Adapter Plus

* fix format

* restore before black format

* restore before black format

* generic

* Refactor PerceiverAttention

* format

* fix test and refactor PerceiverAttention

* generic encode_image

* keep attention implementation

* merge tests

* encode_image backward compatible

* code quality

* fix controlnet inpaint pipeline

* refactor FFN

* refactor FFN

---------

Co-authored-by: YiYi Xu <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024 (same commit message as above)